# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""deferred_resource converts blocking apiclient resource to deferred."""

import datetime
import functools
import httplib
import threading
import time
import traceback

from twisted.internet import defer, reactor, threads
from twisted.python import log as twistedLog
from twisted.python.threadpool import ThreadPool
import apiclient
import apiclient.discovery
import apiclient.errors
import apiclient.http
import httplib2
import oauth2client
import oauth2client.client


# Defaults for DeferredResource's retry_attempt_count and retry_wait_seconds
# (used when those arguments are None).
DEFAULT_RETRY_ATTEMPT_COUNT = 5
DEFAULT_RETRY_WAIT_SECONDS = 1


# Add HTTP 403 Forbidden to oauth2client's REFRESH_STATUS_CODES, presumably so
# that oauth2client refreshes the access token on 403 as it does for the codes
# already in that list. NOTE: this mutates a module-level list of oauth2client
# and therefore affects every oauth2client user in the process.
if httplib.FORBIDDEN not in oauth2client.client.REFRESH_STATUS_CODES:
  oauth2client.client.REFRESH_STATUS_CODES.append(httplib.FORBIDDEN)


class NotStartedError(Exception):
  """Raised when a request is issued while a DeferredResource is not started."""


class DeferredResource(object):
  """Wraps an apiclient Resource, converts its methods to deferred.

  Accepts an apiclient.Resource, such as one generated by
  apiclient.discovery.build, and wraps all resource methods. When a deferred
  resource method is called, it schedules an actual rpc in a twisted thread pool
  and returns a Deferred.

  Has to be explicitly started and stopped. This can be done using "with"
  statement, see examples.

  Examples:
    Basic usage:

      @defer.inlineCallbacks
      def greet():
        # Asynchronously build a DeferredResource for my_greeting_service API.
        service = yield DeferredResource.build('my_greeting_service', 'v1')
        with service:
          response = yield service.api.greet('John')
          defer.returnValue(response)

    Authorization:

      with open(secret_key_filename, 'rb') as f:
        secret_key = f.read()
      AUTH_SCOPE = 'https://www.googleapis.com/auth/userinfo.email'
      creds = SignedJwtAssertionCredentials(service_account, secret_key,
           AUTH_SCOPE)
      service = yield DeferredResource.build(
          'my_greeting_service', 'v1', credentials=creds)

  Also DeferredResource retries requests on transient errors (HTTP 5xx) with
  exponential backoff.
  """

  class Api(object):
    """Dynamically creates resource methods.

    Accessing `api.<name>` returns the Deferred-returning wrapper that the
    owner's _twistify builds for resource method <name>. Wrappers are cached.
    """
    def __init__(self, owner):
      self._methods = {}  # Cache: method name -> twistified wrapper.
      self._owner = owner

    def __getattr__(self, name):
      method = self._methods.get(name)
      if method is None:
        method = self._owner._twistify(name)
        if method is None:
          raise AttributeError('Resource does not have method %s' % name)
        self._methods[name] = method
      return method

  def __init__(
      self, resource, credentials=None, max_concurrent_requests=1,
      retry_wait_seconds=None, retry_attempt_count=None, verbose=False,
      log_prefix='', _pool=None):
    """Creates a DeferredResource.

    Args:
      resource (apiclient.Resource): a resource, such as one generated by
        apiclient.discovery.build.
      credentials (oauth2client.client.Credentials): credentials to use
        to make API requests.
      max_concurrent_requests (int): maximum number of concurrent requests.
        Defaults to 1. Ignored if |_pool| is given.
      retry_wait_seconds (int, float): initial wait interval for request
        retrial. In seconds, defaults to 1.
      retry_attempt_count (int): number of attempts before giving up.
        Defaults to 5.
      verbose (bool): if True, log each request/response.
      log_prefix (str): prefix for log messages.
      _pool (ThreadPool): internal; thread pool to run requests in. A new one
        is created if None.
    """
    max_concurrent_requests = max_concurrent_requests or 1
    assert resource, 'resource not specified'
    if retry_wait_seconds is None:
      retry_wait_seconds = DEFAULT_RETRY_WAIT_SECONDS
    assert isinstance(retry_wait_seconds, (int, float))
    if retry_attempt_count is None:
      retry_attempt_count = DEFAULT_RETRY_ATTEMPT_COUNT
    assert isinstance(retry_attempt_count, int)

    self._pool = _pool or self._create_thread_pool(max_concurrent_requests)
    self._resource = resource
    self.credentials = credentials
    self.retry_wait_seconds = retry_wait_seconds
    self.retry_attempt_count = retry_attempt_count
    self.verbose = verbose
    self.log_prefix = log_prefix
    self.api = self.Api(self)
    # Per-thread state (http object and credentials copy), see _twistify.
    self._th_local = threading.local()
    self.started = False

  @classmethod
  def _create_thread_pool(cls, max_concurrent_requests):
    """Creates a thread pool with at most |max_concurrent_requests| threads."""
    # Normalize falsy values (None, 0) to 1, as __init__ does, so build() and
    # _create_async() accept them too instead of giving ThreadPool a maxthreads
    # smaller than minthreads.
    return ThreadPool(minthreads=1, maxthreads=max_concurrent_requests or 1)

  @classmethod
  def _create_async(
      cls, resource_factory, max_concurrent_requests=1, _pool=None, **kwargs):
    """Creates a DeferredResource without blocking the reactor.

    Calls |resource_factory| in a thread of a temporarily started thread pool,
    wraps its result with cls and stops the pool before firing the returned
    Deferred. The stopped pool is handed to the new DeferredResource, whose
    start() starts it again.

    Args:
      resource_factory (func() -> apiclient.Resource): blocking function that
        creates the resource to wrap.
      max_concurrent_requests (int): maximum number of concurrent requests.
        Used only if |_pool| is None.
      _pool (ThreadPool): thread pool to use. Created if None.
      **kwargs: passed to the DeferredResource constructor.

    Returns:
      A Deferred that fires with the new DeferredResource, or errs back with
      the exception raised while creating it.
    """
    _pool = _pool or cls._create_thread_pool(max_concurrent_requests)
    result = defer.Deferred()

    def create_sync():
      try:
        assert resource_factory, 'resource_factory is not specified'
        res = resource_factory()
        outcome = (result.callback, cls(res, _pool=_pool, **kwargs))
      except Exception as ex:
        outcome = (result.errback, ex)
      # Stop the thread pool only after the (possibly slow) factory returned:
      # Twisted's ThreadPool.stop joins the worker threads in the reactor
      # thread, so scheduling it earlier blocked the reactor for the whole
      # resource creation. It is still scheduled before the result fires, so a
      # client that calls start() from its callback restarts a stopped pool.
      reactor.callFromThread(_pool.stop)
      reactor.callFromThread(*outcome)

    _pool.start()
    _pool.callInThread(create_sync)
    return result

  # Yes, I've copied all these parameters because being explicit is good.
  @classmethod
  def build(
      cls, service_name, version, credentials=None, max_concurrent_requests=1,
      discoveryServiceUrl=apiclient.discovery.DISCOVERY_URI,
      developerKey=None, model=None,
      requestBuilder=apiclient.http.HttpRequest,
      retry_wait_seconds=None, retry_attempt_count=None, verbose=False,
      log_prefix=''):
    """Asynchronously builds a DeferredResource for a discoverable API.

    Asynchronously builds a resource by calling apiclient.discovery.build and
    wraps it with a DeferredResource.

    Args:
      service_name: string, name of the service.
      version: string, the version of the service.
      credentials (oauth2client.client.Credentials): credentials to use
        to make API requests.
      max_concurrent_requests (int): maximum number of concurrent requests.
        Defaults to 1.
      discoveryServiceUrl: string, a URI Template that points to the location of
        the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URI to the
        discovery document for that service.
      developerKey: string, key obtained from
        https://code.google.com/apis/console.
      model: apiclient.Model, converts to and from the wire format.
      requestBuilder: apiclient.http.HttpRequest, encapsulator for an HTTP
        request.
      retry_wait_seconds (int, float): initial wait interval for request
        retrial. In seconds, defaults to 1.
      retry_attempt_count (int): number of attempts before giving up.
        Defaults to 5.
      verbose (bool): if True, log each request/response.
      log_prefix (str): prefix for log messages.

    Returns:
      A DeferredResource as Deferred.
    """
    # Do not check arguments synchronously. Let the client check for exceptions
    # only in errback.
    def resource_factory():
      return apiclient.discovery.build(
          service_name,
          version,
          discoveryServiceUrl=discoveryServiceUrl,
          developerKey=developerKey,
          # |model| used to be accepted but silently dropped.
          model=model,
          requestBuilder=requestBuilder,
      )

    return cls._create_async(
        resource_factory,
        credentials=credentials,
        max_concurrent_requests=max_concurrent_requests,
        retry_wait_seconds=retry_wait_seconds,
        retry_attempt_count=retry_attempt_count,
        verbose=verbose,
        log_prefix=log_prefix,
    )

  def log(self, message):
    """Logs |message| to the twisted log, prefixed with log_prefix."""
    twistedLog.msg('%s%s' % (self.log_prefix, message))

  def start(self):
    """Starts the thread pool; requests are rejected until this is called."""
    self._pool.start()
    self.started = True

  def stop(self):
    """Marks the resource as stopped and stops its thread pool."""
    self.started = False
    self._pool.stop()

  def __enter__(self):
    self.start()
    return self

  def __exit__(self, *args, **kwargs):
    self.stop()

  @defer.inlineCallbacks
  def _retry(self, method_name, call):
    """Retries |call| on transient errors and access token expiration.

    |call| runs in the thread pool. An apiclient HttpError with status >= 500
    is retried after a wait that starts at retry_wait_seconds and doubles
    after each retry, capped at 30 seconds. Any other error, running out of
    attempts, or the resource being stopped ends retrying.

    Args:
      method_name (str): name of the remote method, for logging.
      call (func() -> any): a function that makes an RPC call and returns
        result.

    Returns:
      A Deferred that fires with the result of |call|.

    Raises (as errback):
      NotStartedError: if the resource is not started.
      Exception: the last error raised by |call|.
    """
    # Always make at least one attempt. With a non-positive retry_attempt_count
    # the loop would otherwise never run and the Deferred would fire with None
    # without |call| ever being invoked.
    attempts = max(self.retry_attempt_count, 1)
    wait = self.retry_wait_seconds
    while attempts > 0:
      attempts -= 1
      try:
        if not self.started:
          raise NotStartedError('DeferredResource is not started')
        res = yield threads.deferToThreadPool(reactor, self._pool, call)
        defer.returnValue(res)
      except Exception as ex:
        if not self.started:
          raise ex
        if isinstance(ex, apiclient.errors.HttpError):
          status = ex.resp.status if ex.resp else None
          if status is not None and status >= 500 and attempts > 0:
            # %s rather than %d: the wait may be fractional (e.g. 0.5).
            self.log('Transient error while calling %s. '
                     'Will retry in %s seconds.' % (method_name, wait))
            # TODO(nodir), optimize: stop waiting if the resource is stopped.
            yield sleep(wait)
            if not self.started:
              raise ex
            wait = min(wait * 2, 30)
            continue
        self.log('RPC "%s" failed: %s' % (method_name, traceback.format_exc()))
        raise ex

  def _log_request(self, method_name, args, kwargs):
    """Logs a request as a call expression, e.g. `Request greet('John')`."""
    arg_str_list = [repr(arg) for arg in args]
    arg_str_list += ['%s=%r' % (k, v) for k, v in kwargs.items()]
    self.log('Request %s(%s)' % (method_name, ', '.join(arg_str_list)))

  def _twistify(self, method_name):
    """Wraps a resource method by name.

    Args:
      method_name (str): name of an attribute of the wrapped resource.

    Returns:
      A function taking the same arguments as the resource method. It executes
      the request in the thread pool (with retries, see _retry) and returns a
      Deferred. Returns None if the resource has no such attribute.
    """
    method = getattr(self._resource, method_name, None)
    if method is None:
      return None

    @functools.wraps(method)
    def twistified(*args, **kwargs):
      def single_call():
        # Runs in a pool thread. Each thread lazily creates its own
        # httplib2.Http and, if credentials were given, its own copy of the
        # credentials (round-tripped through JSON) used to authorize it,
        # presumably because these objects are not safe to share across
        # threads.
        if getattr(self._th_local, 'http', None) is None:
          self._th_local.credentials = None
          self._th_local.http = httplib2.Http()
          if self.credentials:
            self._th_local.credentials = self.credentials.from_json(
                self.credentials.to_json())
            self._th_local.http = self._th_local.credentials.authorize(
                self._th_local.http)
        elif getattr(self._th_local.credentials, 'token_expiry', None):
          # Check token_expiry more aggressively:
          # refresh if it expires in <= 5 min.
          # NOTE(review): assumes token_expiry is a naive UTC datetime.
          expiry = (
              self._th_local.credentials.token_expiry -
              datetime.timedelta(minutes=5))
          if expiry <= datetime.datetime.utcnow():
            self._th_local.credentials.refresh(self._th_local.http)

        if self.verbose:
          self._log_request(method_name, args, kwargs)
        response = method(*args, **kwargs).execute(self._th_local.http)
        if self.verbose:
          self.log('Response: %s' % response)
        return response
      return self._retry(method_name, single_call)
    return twistified


def sleep(secs):
  """Returns a Deferred that fires with None after |secs| seconds.

  Cancelling the returned Deferred before it fires also cancels the pending
  reactor timer. Otherwise the timer would later call back a Deferred that
  has already fired (with CancelledError).

  Args:
    secs (int, float): delay in seconds.
  """
  def cancel_timer(_):
    # Only invoked by Deferred.cancel() while the Deferred has not fired yet,
    # i.e. while the timer is still pending.
    timer.cancel()

  d = defer.Deferred(cancel_timer)
  timer = reactor.callLater(secs, d.callback, None)
  return d
